feat: cute dsl mmfp4 for blackwell #2540
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough
Adds "cute-dsl" as a new mm_fp4 FP4 GEMM backend: benchmarks and tests updated; mm_fp4 dispatch extended with CuTe DSL availability checks, a runner factory entry, and a kernel cache; and a new SM100 block-scaled persistent GEMM kernel was added.
Sequence Diagram

sequenceDiagram
participant User as User
participant MM as mm_fp4
participant Disp as Dispatcher
participant Req as Requirement (_cute_dsl_gemm_fp4_requirement)
participant Run as Runner (_cute_dsl_gemm_fp4_runner)
participant Cache as Kernel Cache
participant Kernel as CuTeDSL Kernel
User->>MM: mm_fp4(..., backend="cute-dsl", enable_pdl=...)
MM->>Disp: select backend runner
Disp->>Req: validate availability & constraints
alt invalid
Req-->>MM: raise/skip
else valid
Disp->>Run: create/obtain runner
Run->>Cache: lookup compiled kernel by config
alt cached
Cache-->>Run: return kernel
else not cached
Run->>Kernel: compile kernel
Kernel-->>Cache: store compiled kernel
Cache-->>Run: return kernel
end
Run-->>Disp: runner instance
Disp->>Kernel: execute kernel with tensors
Kernel-->>User: results
end
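A minimal sketch of the compile-on-miss caching pattern the diagram describes (illustrative names only, not the actual FlashInfer API):

_KERNEL_CACHE = {}  # maps a config tuple to a compiled kernel

def get_or_compile(config, compile_fn):
    # Return a cached kernel for this config, compiling and storing it on a miss.
    if config not in _KERNEL_CACHE:
        _KERNEL_CACHE[config] = compile_fn(config)
    return _KERNEL_CACHE[config]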
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes
Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Code Review
This pull request integrates cute_dsl as a new backend for mm_fp4, which is a significant and valuable addition. The changes are well-structured, introducing new, high-performance kernels ported from NVIDIA's libraries and integrating them consistently with the existing backend infrastructure. The new code is complex but appears to be of high quality. I've identified one potential issue in the autotuner's tactic generation logic where an alignment check seems to be incorrect, which could lead to suboptimal kernel selection. Overall, this is an excellent contribution that should improve FP4 GEMM performance.
if swap_ab and not m_aligned:
    continue
The alignment check for the output matrix C when swap_ab is true appears to be incorrect. When swap_ab is true, the kernel computes B.T @ A.T, and the output is effectively a column-major matrix of shape (n, m). The contiguous dimension in memory is along the columns, which corresponds to the problem's n dimension. Therefore, the alignment check should be on n (n_aligned), not m (m_aligned). This incorrect pruning might exclude valid and potentially optimal kernel configurations.
- if swap_ab and not m_aligned:
-     continue
+ if swap_ab and not n_aligned:
+     continue
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3068-3173: The cache key used to index _CUTE_DSL_KERNEL_CACHE must
include the device identity to avoid reusing device-specific compiled kernels
across GPUs; modify the construction of cache_key (the tuple currently
containing sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch,
kernel_type, use_tma_store, enable_pdl, out_dtype) to also incorporate the
executing device (derive from kernel_a.device — include device.type and
device.index (or a stable sentinel like -1 if index is None)), and use that
augmented cache_key when reading/writing _CUTE_DSL_KERNEL_CACHE for
compiled_gemm and max_active_clusters so the lookup/store around compiled_gemm
and max_active_clusters becomes device-aware.
- Around line 3176-3192: The kernel assumes row-major memory when swap_ab=True
but launch_out is set to the non-contiguous view out.T; change the launch path
so the kernel receives a contiguous buffer with the expected layout: when
swap_ab is True, allocate a temporary contiguous tensor with the row-major
layout (or call out.clone().contiguous()) into which the kernel will write (this
is the launch_out passed to the kernel), then after the kernel completes copy
the results back into the original out via the appropriate transpose (e.g.,
out.copy_(temp.T)) and free the temp; alternatively, ensure out is originally
allocated with the layout expected by cute.make_ordered_layout so no transpose
view is used. Ensure this change is applied around the launch_out assignment and
kernel invocation that uses swap_ab and interacts with cute.make_ordered_layout.
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py`:
- Around line 1655-1658: The docstring documents a non-existent parameter "sepi"
alongside tTR_rC, tidx, and sC; remove the stale "sepi (cute.Tensor):" entry
from the function's docstring (the block describing tTR_rC, tidx, sC) so the
parameter list matches the actual function signature and leave only real
parameters (e.g., tTR_rC, tidx, sC).
- Around line 2017-2021: Rename the helper function
check_contigous_16B_alignment to check_contiguous_16B_alignment and update all
call sites that invoke it (the three places currently calling
check_contigous_16B_alignment) to use the new name; ensure the function
signature (dtype, is_mode0_major, tensor_shape) remains unchanged so callers
still pass the same arguments and behavior is preserved.
🧹 Nitpick comments (4)
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py (4)
648-648: Nit: prefix unused unpacked variables with `_`.
`bidy` and `bidz` are never referenced. Prefixing them with `_` silences linter warnings and signals intent.
Proposed fix
- bidx, bidy, bidz = cute.arch.block_idx()
+ bidx, _bidy, _bidz = cute.arch.block_idx()
1677-1719: Unused parameter `tidx`.
`tidx` is accepted but never referenced inside `epilog_gmem_copy_and_partition`. If it's kept for API symmetry with the other `epilog_*_copy_and_partition` methods, consider documenting that intent. Otherwise, remove it.
1721-1829: Unused parameters `a_major_mode` and `b_major_mode` in `_compute_stages`.
These are passed through from the caller but never referenced in the method body. They appear to be placeholders; consider removing them or adding a comment if they're reserved for future heuristics.
Also note: `num_ab_stage` (line 1816) has no lower-bound clamp. If shared memory is under-provisioned for a given configuration, this could yield ≤ 1 stages, breaking double-buffered pipeline semantics. The upstream `can_implement` validation likely guards against this in practice, but a defensive `max(num_ab_stage, 2)` would be safer.
1923-1947: Unused parameters `c_dtype` and `c_major` in `is_valid_layouts`.
These parameters are accepted but never checked. If all C layouts are valid, remove them from the signature (and update callers). If C-layout validation is planned, consider adding a `# TODO` to track it.
cache_key = (
    sf_vec_size,
    mma_tiler_mn,
    cluster_shape_mn,
    swap_ab,
    use_prefetch,
    kernel_type,
    use_tma_store,
    enable_pdl,
    out_dtype,
)

if cache_key not in _CUTE_DSL_KERNEL_CACHE:
    # Create kernel instance
    if kernel_type == "sm103" and Sm103Kernel is not None:
        gemm = Sm103Kernel(  # type: ignore[assignment]
            sf_vec_size,
            mma_tiler_mn,
            cluster_shape_mn,
            use_tma_store,
            enable_pdl,
        )
    else:
        gemm = Sm100BlockScaledPersistentDenseGemmKernel(  # type: ignore[assignment]
            sf_vec_size,
            mma_tiler_mn,
            cluster_shape_mn,
            use_prefetch,
            enable_pdl,
        )

    # Create CuTe pointers for compilation
    a_ptr = make_ptr(
        cutlass.Float4E2M1FN,
        kernel_a.data_ptr(),
        cute.AddressSpace.gmem,
        32,
    )
    b_ptr = make_ptr(
        cutlass.Float4E2M1FN,
        kernel_b.data_ptr(),
        cute.AddressSpace.gmem,
        32,
    )
    a_sf_ptr = make_ptr(
        cutlass.Float8E4M3FN,
        kernel_a_sf.data_ptr(),
        cute.AddressSpace.gmem,
        16,
    )
    b_sf_ptr = make_ptr(
        cutlass.Float8E4M3FN,
        kernel_b_sf.data_ptr(),
        cute.AddressSpace.gmem,
        16,
    )
    c_ptr = make_ptr(
        c_cutlass_dtype, out.data_ptr(), cute.AddressSpace.gmem, 16
    )

    # Alpha: ensure 1-dim shape [1] for consistent TVM FFI compilation
    if alpha_tensor is not None:
        alpha_compile = (
            alpha_tensor.reshape(1)
            if alpha_tensor.dim() != 1
            else alpha_tensor
        )
        alpha_cute = cute.runtime.from_dlpack(alpha_compile)
    else:
        alpha_cute = cute.runtime.from_dlpack(
            torch.tensor([1.0], dtype=torch.float32, device=a.device)
        )

    # Get max active clusters
    from flashinfer.cute_dsl.utils import get_max_active_clusters

    max_active_clusters = get_max_active_clusters(
        cluster_shape_mn[0] * cluster_shape_mn[1]
    )

    # Create a fake stream for compilation
    stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

    compiled_gemm = cute.compile(
        gemm.wrapper,
        kernel_m,
        kernel_n,
        real_k,
        sf_m,
        sf_n,
        sf_k,
        batch_size,
        a_ptr,
        b_ptr,
        a_sf_ptr,
        b_sf_ptr,
        c_ptr,
        alpha_cute,
        max_active_clusters,
        stream,
        swap_ab,
        options="--opt-level 2 --enable-tvm-ffi",
    )

    _CUTE_DSL_KERNEL_CACHE[cache_key] = (compiled_gemm, max_active_clusters)
🧩 Analysis chain
(Verification scripts run against flashinfer/gemm/gemm_base.py and flashinfer/cute_dsl/utils.py, inspecting the _CUTE_DSL_KERNEL_CACHE usage, the runner factory, and get_max_active_clusters; script output omitted.)
Cache key must include device identity to prevent cross-device reuse in multi-GPU setups.
_CUTE_DSL_KERNEL_CACHE lacks device-specific information in its cache key. The current 9-tuple key includes kernel parameters but omits device identity; compiled GEMM binaries and max_active_clusters are device-specific. In multi-GPU or heterogeneous SM configurations, this allows incorrect reuse of kernels compiled for different devices. Include device type and index in the cache key:
🔧 Suggested fix (device-aware cache key)
# Cache key for compiled kernel
+ device_key = (a.device.type, a.device.index)
cache_key = (
+ device_key,
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
swap_ab,
use_prefetch,
kernel_type,
use_tma_store,
enable_pdl,
out_dtype,
)
🤖 Prompt for AI Agents
In `@flashinfer/gemm/gemm_base.py` around lines 3068 - 3173, The cache key used to
index _CUTE_DSL_KERNEL_CACHE must include the device identity to avoid reusing
device-specific compiled kernels across GPUs; modify the construction of
cache_key (the tuple currently containing sf_vec_size, mma_tiler_mn,
cluster_shape_mn, swap_ab, use_prefetch, kernel_type, use_tma_store, enable_pdl,
out_dtype) to also incorporate the executing device (derive from kernel_a.device
— include device.type and device.index (or a stable sentinel like -1 if index is
None)), and use that augmented cache_key when reading/writing
_CUTE_DSL_KERNEL_CACHE for compiled_gemm and max_active_clusters so the
lookup/store around compiled_gemm and max_active_clusters becomes device-aware.
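For illustration, one way to build the device-aware portion of the key as the prompt suggests (the helper name is hypothetical; the real code would inline this where cache_key is constructed):

def _device_cache_key(tensor):
    # Use -1 as a stable sentinel when the device index is unset (e.g., plain "cuda").
    dev = tensor.device
    return (dev.type, dev.index if dev.index is not None else -1)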
# Handle output tensor for swap_ab
if swap_ab:
    launch_out = out.T
else:
    launch_out = out

# Prepare alpha: ensure it is always a 1-dim tensor with shape [1].
# mm_fp4 may pass alpha as a 0-dim scalar tensor or a 1-dim tensor;
# the TVM FFI compiled kernel requires a consistent 1-dim shape.
if alpha_tensor is None:
    alpha_for_launch = torch.tensor(
        [1.0], dtype=torch.float32, device=a.device
    )
elif alpha_tensor.dim() == 0:
    alpha_for_launch = alpha_tensor.unsqueeze(0)
else:
    alpha_for_launch = alpha_tensor.reshape(1)
🧩 Analysis chain
(Verification scripts run against flashinfer/gemm/gemm_base.py, flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py, and related kernel utilities, inspecting launch_out, the swap_ab handling, and how the output layout is constructed; script output omitted.)
Confirm kernel layout handling for transposed output. The kernel is JIT-compiled with swap_ab as a parameter and creates a tensor layout via cute.make_ordered_layout based on this flag. When swap_ab=True, it creates a row-major tensor layout (order=(0,1,2)) assuming strides match that layout. However, out.T is a non-contiguous view with different strides than row-major. Since the kernel receives only the base pointer (no stride metadata), it will interpret the memory using incorrect stride assumptions, causing incorrect writes. Either materialize the transposed output as a contiguous buffer and copy results back, or ensure the output tensor is allocated with the correct layout from the start.
🤖 Prompt for AI Agents
In `@flashinfer/gemm/gemm_base.py` around lines 3176 - 3192, The kernel assumes
row-major memory when swap_ab=True but launch_out is set to the non-contiguous
view out.T; change the launch path so the kernel receives a contiguous buffer
with the expected layout: when swap_ab is True, allocate a temporary contiguous
tensor with the row-major layout (or call out.clone().contiguous()) into which
the kernel will write (this is the launch_out passed to the kernel), then after
the kernel completes copy the results back into the original out via the
appropriate transpose (e.g., out.copy_(temp.T)) and free the temp;
alternatively, ensure out is originally allocated with the layout expected by
cute.make_ordered_layout so no transpose view is used. Ensure this change is
applied around the launch_out assignment and kernel invocation that uses swap_ab
and interacts with cute.make_ordered_layout.
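For illustration, a minimal sketch of the temporary-buffer approach described above (this is not the PR's code; it reuses the variable names from the excerpt and assumes a 2-D out):

if swap_ab:
    # Contiguous row-major buffer of shape (n, m), matching the layout the
    # kernel was compiled for when swap_ab=True.
    temp = torch.empty(out.T.shape, dtype=out.dtype, device=out.device)
    launch_out = temp
else:
    launch_out = out

# ... launch the compiled kernel, which writes into launch_out ...

if swap_ab:
    out.copy_(temp.T)  # fold the transposed result back into the caller's output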
def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
    major_mode_idx = 0 if is_mode0_major else 1
    num_major_elements = tensor_shape[major_mode_idx]
    num_contiguous_elements = 16 * 8 // dtype.width
    return num_major_elements % num_contiguous_elements == 0
Typo: check_contigous_16B_alignment → check_contiguous_16B_alignment.
Minor typo in the inner helper name ("contigous" → "contiguous").
Proposed fix
- def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
+ def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):

Update the three call sites on lines 2024–2026 accordingly.
🤖 Prompt for AI Agents
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py` around lines 2017 -
2021, Rename the helper function check_contigous_16B_alignment to
check_contiguous_16B_alignment and update all call sites that invoke it (the
three places currently calling check_contigous_16B_alignment) to use the new
name; ensure the function signature (dtype, is_mode0_major, tensor_shape)
remains unchanged so callers still pass the same arguments and behavior is
preserved.
bkryu
left a comment
Thanks @nv-yunzheq, left a number of comments.
/bot run
bkryu
left a comment
Thanks for updating. No concerns on my end but will wait for a few more pairs of eyes before approving
Dismissing "request for change" as requested changes have been made
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3392-3397: Docstring inconsistency: replace the underscore form
`cute_dsl` with the exact backend literal `"cute-dsl"` in backticks wherever it
appears in the docstring for the enable_pdl parameter (and the other occurrence
noted around line 3402) so the documentation matches the actual backend name;
update the text referencing enable_pdl to read `\"cute-dsl\"` (in backticks) to
ensure consistent naming across the docstring for the enable_pdl parameter and
related descriptive lines.
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py`:
- Around line 1807-1824: The computed stage counts can become non-positive;
after calculating num_ab_stage and refining num_c_stage (using smem_capacity,
occupancy, mbar_helpers_bytes, c_bytes, ab_bytes_per_stage, c_bytes_per_stage),
clamp them to safe minima (e.g., num_ab_stage = max(1, num_ab_stage) and
num_c_stage = max(2, num_c_stage)) or raise a clear exception if the tile
configuration is invalid; do this just before the return in the function that
computes stages so the pipeline never receives <=0 stages and include a short
error message if you choose to raise.
🧹 Nitpick comments (6)
flashinfer/gemm/gemm_base.py (3)
3194-3197: Avoid allocating a new tensor on every forward call when `alpha` is `None`.
`torch.tensor([1.0], ...)` allocates a new CUDA tensor on every invocation. For a hot GEMM path, consider caching the default alpha once (e.g., as an instance attribute or a module-level constant per device).
♻️ Suggested approach
+ # Cache a default alpha=1.0 tensor to avoid per-call allocation
+ _default_alpha_cache = {}
+
  # Prepare alpha: ensure it is always a 1-dim tensor with shape [1].
  if alpha_tensor is None:
-     alpha_for_launch = torch.tensor(
-         [1.0], dtype=torch.float32, device=a.device
-     )
+     device = a.device
+     if device not in _default_alpha_cache:
+         _default_alpha_cache[device] = torch.tensor(
+             [1.0], dtype=torch.float32, device=device
+         )
+     alpha_for_launch = _default_alpha_cache[device]

You could place `_default_alpha_cache` as a class attribute on `CuteDSLFp4GemmRunner` or a closure variable in `_cute_dsl_gemm_fp4_runner`.
2939-2950: Hoist the `get_device_properties` call outside the loop.
`torch.cuda.get_device_properties(a.device).multi_processor_count` is called inside nested loops for each `use_prefetch=True` candidate. Move it before the loop to avoid repeated lookups.
♻️ Suggested change
Add before the `for mma_tiler_mn` loop (around line 2905):

sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count

Then replace lines 2945-2947:

- sm_count = torch.cuda.get_device_properties(
-     a.device
- ).multi_processor_count
2808-2821: Noted: SM103 kernel disabled with clear TODO.
The commented-out SM103 import with the explanatory TODO and the explicit `Sm103Kernel = None` sentinel is clear. Consider tracking this with a GitHub issue so it doesn't get lost.
Would you like me to open an issue to track re-enabling the SM103 kernel once the cutlass-dsl package supports `SM103MmaMXF4Op`?
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py (3)
1672-1678: `tidx` parameter is unused.
`tidx` is accepted but never referenced in the method body. If it's kept for API consistency with sibling `epilog_*` methods, consider prefixing with underscore (`_tidx`) to signal intent.
1918-1942: `c_dtype` and `c_major` parameters are unused.
These are accepted but never referenced in the validation logic. If they're placeholders for future constraints, consider adding a brief comment or prefixing with underscore.
1968-1968: Lambda assigned to a variable; prefer a `def` (Ruff E731).
Proposed fix
- _is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
+ def _is_power_of_2(x):
+     return x > 0 and (x & (x - 1)) == 0
enable_pdl: bool
    Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl``
    backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
    with the start of the next for reduced launch latency. This parameter is
    only used by the ``cute_dsl`` backend and is ignored by other backends.
Minor inconsistency: cute_dsl vs cute-dsl naming in docstring.
Line 3393 references cute_dsl (underscore) while the backend literal uses "cute-dsl" (hyphen). Consider using backticks with the exact backend string "cute-dsl" consistently to avoid user confusion.
📝 Suggested fix
enable_pdl: bool
- Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl``
+ Whether to enable Programmatic Dependent Launch (PDL) for the ``"cute-dsl"``
backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
with the start of the next for reduced launch latency. This parameter is
- only used by the ``cute_dsl`` backend and is ignored by other backends.
+ only used by the ``"cute-dsl"`` backend and is ignored by other backends.Also on line 3402:
- When cute_dsl backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False (same as cudnn/cutlass).
+ When ``"cute-dsl"`` backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False (same as cudnn/cutlass).📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- enable_pdl: bool
-     Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl``
-     backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
-     with the start of the next for reduced launch latency. This parameter is
-     only used by the ``cute_dsl`` backend and is ignored by other backends.
+ enable_pdl: bool
+     Whether to enable Programmatic Dependent Launch (PDL) for the ``"cute-dsl"``
+     backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
+     with the start of the next for reduced launch latency. This parameter is
+     only used by the ``"cute-dsl"`` backend and is ignored by other backends.
🤖 Prompt for AI Agents
In `@flashinfer/gemm/gemm_base.py` around lines 3392 - 3397, Docstring
inconsistency: replace the underscore form `cute_dsl` with the exact backend
literal `"cute-dsl"` in backticks wherever it appears in the docstring for the
enable_pdl parameter (and the other occurrence noted around line 3402) so the
documentation matches the actual backend name; update the text referencing
enable_pdl to read `\"cute-dsl\"` (in backticks) to ensure consistent naming
across the docstring for the enable_pdl parameter and related descriptive lines.
# Calculate A/B/SFA/SFB stages:
# Start with total smem per CTA (capacity / occupancy)
# Subtract reserved bytes and initial C stages bytes
# Divide remaining by bytes needed per A/B/SFA/SFB stage
num_ab_stage = (
    smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage

# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes
# Add remaining unused smem to epilogue
num_c_stage += (
    smem_capacity
    - occupancy * ab_bytes_per_stage * num_ab_stage
    - occupancy * (mbar_helpers_bytes + c_bytes)
) // (occupancy * c_bytes_per_stage)

return num_acc_stage, num_ab_stage, num_c_stage
No lower-bound guard on computed stage counts.
If smem_capacity / occupancy is too small for the chosen tile configuration, num_ab_stage (Line 1812) could compute to ≤ 0, and the C-stage refinement (Line 1818) could reduce num_c_stage below the initial value of 2. Both would lead to invalid pipeline configurations at runtime.
Consider adding a minimum-stage assertion or early-return:
Proposed guard
num_ab_stage = (
smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
+ assert num_ab_stage >= 2, (
+ f"Not enough shared memory for at least 2 A/B stages "
+ f"(got {num_ab_stage}). Consider reducing tile size or cluster shape."
+ )
# Refine epilogue stages:
🤖 Prompt for AI Agents
In `@flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py` around lines 1807 -
1824, The computed stage counts can become non-positive; after calculating
num_ab_stage and refining num_c_stage (using smem_capacity, occupancy,
mbar_helpers_bytes, c_bytes, ab_bytes_per_stage, c_bytes_per_stage), clamp them
to safe minima (e.g., num_ab_stage = max(1, num_ab_stage) and num_c_stage =
max(2, num_c_stage)) or raise a clear exception if the tile configuration is
invalid; do this just before the return in the function that computes stages so
the pipeline never receives <=0 stages and include a short error message if you
choose to raise.
[FAILED] Pipeline #43830201: 16/20 passed
-----
When cudnn/cutlass backend is used, both a and b should quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False.
When trtllm backend is used, b must be quantized with 128x4 layout and `do_shuffle=True`. a can be quantized with either 128x4 or 8x4 layout (controlled by `use_8x4_sf_layout`) and `do_shuffle=False`.
When cute_dsl backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False (same as cudnn/cutlass).
Hi @nv-yunzheq, I have a quick question, I am just wondering if the scale layout is the same, is there any reason for not using this for auto-tuner when backend=auto? Or can it still be considered experimental for now
Hi @b8zhong, yes it is meant to be experimental at least for this PR; nothing inherently blocking the cute-dsl backend from being considered for autotuning when backend=auto. We'd like to get the kernel in first knowing that it has satisfactory perf, and then perform a more detailed perf study.
We may include the cute-dsl backend as a config that can be autotuned in a followup PR.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2957-2962: The new parameters backend, use_nvfp4, and enable_pdl
are triggering Ruff ARG001/ARG002 unused-argument warnings; to silence them,
explicitly consume or acknowledge the arguments in each function that declares
them (e.g., add a line like "_ = backend, use_nvfp4, enable_pdl" near the top of
the function) or append a per-function noqa (e.g., "# noqa: ARG001") to the def
line; apply this change consistently for every function that declares these
parameters (the overloads/definitions using backend, use_nvfp4, enable_pdl in
this file).
---
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3564-3567: The non-contiguous view created by launch_out = out.T
when swap_ab is True can break kernel assumptions; change this to produce a
contiguous tensor (e.g., launch_out = out.T.contiguous() or launch_out =
out.transpose(0,1).contiguous()) so launch_out is contiguous in memory before
passing to the kernel; update the swap_ab branch where launch_out, swap_ab, and
out.T are used to ensure the contiguous output is supplied.
- Around line 3444-3455: The cache key tuple named cache_key (constructed from
sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch, kernel_type,
use_tma_store, enable_pdl, out_dtype) is missing any device identity and can
incorrectly reuse kernels across GPUs; update the cache_key to include a
device-unique identifier (e.g., the CUDA device ordinal or a stable device
identifier such as PCI bus id / device UUID or
torch.cuda.get_device_properties(device).name+index) so compiled kernels are
cached per-device. Ensure you retrieve the current device from the same context
where kernels are compiled and append that identifier to the cache_key tuple.
backend: Literal[
    "cudnn", "trtllm", "cutlass", "cute-dsl", "auto"
] = "auto",  # unused
use_nvfp4: bool = True,
enable_pdl: bool = True,  # unused
):
Silence Ruff unused-argument warnings for new backend/PDL parameters.
Ruff flags these as unused (ARG001/ARG002). If lint is enforced, consider explicitly consuming them (e.g., _ = backend, enable_pdl) or adding a # noqa: ARG001 on the def line.
💡 Example pattern
def _check_mm_fp4_problem_size(..., backend=..., use_nvfp4=True, enable_pdl=True):
+     _ = backend, enable_pdl

Also applies to: 3017-3022, 3081-3086, 3108-3113, 3121-3153, 3612-3614
🧰 Tools
🪛 Ruff (0.15.1)
[warning] 2957-2957: Unused function argument: backend
(ARG001)
[warning] 2961-2961: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` around lines 2957 - 2962, The new parameters
backend, use_nvfp4, and enable_pdl are triggering Ruff ARG001/ARG002
unused-argument warnings; to silence them, explicitly consume or acknowledge the
arguments in each function that declares them (e.g., add a line like "_ =
backend, use_nvfp4, enable_pdl" near the top of the function) or append a
per-function noqa (e.g., "# noqa: ARG001") to the def line; apply this change
consistently for every function that declares these parameters (the
overloads/definitions using backend, use_nvfp4, enable_pdl in this file).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3818-3829: The docstring uses the backend name cute_dsl
inconsistently with the project's canonical backend id "cute-dsl"; update the
occurrences in this docstring (the Notes block and the enable_pdl description)
to use "cute-dsl" (including quotes where other backend names are quoted) so the
naming matches other docs and earlier comments referencing the cute-dsl backend
and the enable_pdl parameter.
- Around line 3458-3576: The cache key built for _CUTE_DSL_MM_FP4_KERNEL_CACHE
(variable cache_key) is missing device identity, causing cross-device kernel
reuse; update the cache_key creation in gemm_base.py to include the current CUDA
device identifier (e.g., torch.cuda.current_device() or equivalent from the
cute/runtime/stream) so compiled_gemm and max_active_clusters are cached
per-device; ensure the same device id is used when looking up and storing
entries in _CUTE_DSL_MM_FP4_KERNEL_CACHE (refer to symbols cache_key,
_CUTE_DSL_MM_FP4_KERNEL_CACHE, compiled_gemm, max_active_clusters).
- Around line 3578-3583: The swap_ab branch assigns a non-contiguous view via
out.T to launch_out which can break downstream kernels; instead ensure
launch_out is a contiguous transposed tensor by replacing the out.T usage with
an explicit transpose followed by making it contiguous (e.g., use
out.transpose(...).contiguous() or out.t().contiguous()) so that launch_out is
contiguous when swap_ab is true; update the block that sets launch_out (the
swap_ab conditional around launch_out and out) accordingly.
- Around line 2961-3151: The Ruff warnings come from unused parameters (backend,
enable_pdl, and similar) introduced in the FP4 requirement helpers; to silence
them, explicitly mark those parameters as deliberately unused by either renaming
to a leading-underscore variant or adding a single-line discard (e.g., del
backend, enable_pdl) at the top of each affected function; apply this change in
_check_mm_fp4_problem_size, _cudnn_gemm_fp4_requirement,
_trtllm_gemm_fp4_requirement, _cutlass_gemm_fp4_requirement, and
_cute_dsl_gemm_fp4_requirement so Ruff no longer reports unused-argument
warnings while keeping the API unchanged.
@flashinfer-bot rerun failed

@flashinfer-bot stop

@flashinfer-bot rerun failed

/bot run

[FAILED] Pipeline #44404621: 9/20 passed
| "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" | ||
| ] = "auto", # unused | ||
| use_nvfp4: bool = True, | ||
| enable_pdl: bool = True, # unused |
Is this unused? If so, why has it been added?
This is the function that checks whether the given operation is runnable. It has to have the exact same arguments as the mm_fp4 function itself. However, some of the parameters are not used in this support-check function.
| "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" | ||
| ] = "auto", # unused | ||
| use_nvfp4: bool = True, | ||
| enable_pdl: bool = True, # unused |
Why are all these arguments marked as # unused?
As mentioned above, it needs to have the same function signature as mm_fp4. However, when checking whether the cute_dsl backend is viable, we don't need any of these input parameters to determine if it's runnable; we only check whether cute_dsl is installed or not.
The # unused marker is for the pre-commit check. The pre-commit reformatting would reject a function with unused parameters. We mark them to suppress that behavior.
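For illustration, a minimal sketch of the pattern described in this exchange (all names and parameters here are placeholders, not the real mm_fp4 signature):

def _cute_dsl_requirement_sketch(a, b, out_dtype=None, backend="auto",
                                 use_nvfp4=True, enable_pdl=True):
    # Parameters are kept only for signature parity with the public API and are
    # intentionally unused; the check consults backend availability alone.
    _ = (a, b, out_dtype, use_nvfp4, enable_pdl)
    return backend == "cute-dsl" and cute_dsl_is_installed()  # hypothetical availability helper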
📌 Description
Issue #2466
The PR integrates cute_dsl as a new backend for mm_fp4.
dense_blockscaled_gemm_sm100.py comes from dense_blockscaled_gemm_persistent.py from TensorRT-LLM.
dense_blockscaled_gemm_sm103.py comes from sm103_dense_blockscaled_gemm_persistent.py from CUTLASS. This file is integrated, but is not currently being used as it requires a pre-released version of nvidia-cutlass-dsl.
gemm_base.py contains the main wrapper logic for the mm_fp4 cute dsl gemm kernel.
Also update the mm_fp4 unit test and benchmark script to test the cute_dsl backend.
The performance data:
MMFP4 Benchmark Results
GB200 Non-Autotune
GB200 Autotune
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests